{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import mindspore\n",
    "from mindspore.dataset import text, GeneratorDataset, transforms\n",
    "from mindspore import nn, context\n",
    "from mindnlp.engine.callbacks import CheckpointCallback\n",
    "from mindnlp.transforms import PadTransform\n",
    "from mindnlp.models import BertModel, BertConfig\n",
    "from mindnlp.transforms.tokenizers import BertTokenizer\n",
    "from mindnlp.engine import Trainer, Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run MindSpore in graph mode on the GPU backend.\n",
    "context.set_context(\n",
    "    mode=context.GRAPH_MODE,\n",
    "    device_target=\"GPU\",\n",
    ")\n",
    "\n",
    "# prepare dataset\n",
    "class SentimentDataset:\n",
    "    \"\"\"Random-access source of (label, text) samples read from a TSV file.\n",
    "\n",
    "    The first line of the file is treated as a header and skipped. Every\n",
    "    other non-blank line must hold an integer label, a tab, and the text.\n",
    "\n",
    "    Args:\n",
    "        path: Path of the TSV file to load.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a non-blank line has no tab or a non-integer label.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, path):\n",
    "        self.path = path\n",
    "        self._labels, self._text_a = [], []\n",
    "        self._load()\n",
    "\n",
    "    def _load(self):\n",
    "        \"\"\"Parse the TSV file into the parallel label and text lists.\"\"\"\n",
    "        with open(self.path, \"r\", encoding=\"utf-8\") as f:\n",
    "            dataset = f.read()\n",
    "        lines = dataset.split(\"\\n\")\n",
    "        # lines[0] is the header. Slice with [1:] instead of [1:-1] so the\n",
    "        # last record is kept when the file does not end with a newline.\n",
    "        for line in lines[1:]:\n",
    "            # Skip blank lines, e.g. the empty tail left by a final newline.\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            # Split on the first tab only so tabs inside the text survive.\n",
    "            label, text_a = line.split(\"\\t\", 1)\n",
    "            self._labels.append(int(label))\n",
    "            self._text_a.append(text_a)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the ``(label, text)`` pair stored at ``index``.\"\"\"\n",
    "        return self._labels[index], self._text_a[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of loaded samples.\"\"\"\n",
    "        return len(self._labels)\n",
    "\n",
    "# please download the dataset from \"https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz\"\n",
    "# and extract it to the \"data\" folder\n",
    "column_names = [\"label\", \"text_a\"]\n",
    "dataset_train = GeneratorDataset(source=SentimentDataset(\"data/train.tsv\"),\n",
    "                                 column_names=column_names, shuffle=True)\n",
    "dataset_val = GeneratorDataset(source=SentimentDataset(\"data/dev.tsv\"),\n",
    "                               column_names=column_names, shuffle=True)\n",
    "dataset_test = GeneratorDataset(source=SentimentDataset(\"data/test.tsv\"),\n",
    "                                column_names=column_names, shuffle=False)\n",
    "\n",
    "vocab_path = os.path.join(\"data\", \"vocab.txt\")\n",
    "vocab = text.Vocab.from_file(vocab_path)\n",
    "vocab_size = len(vocab.vocab())\n",
    "\n",
    "pad_value_text = vocab.tokens_to_ids('[PAD]')\n",
    "tokenizer = BertTokenizer(vocab=vocab)\n",
    "pad_op_text = PadTransform(max_length=64, pad_value=pad_value_text)\n",
    "type_cast_op = transforms.TypeCast(mindspore.int32)\n",
    "\n",
    "# rename the columns because the model's construct function requires the parameter input_ids\n",
    "rename_columns = [\"label\", \"input_ids\"]\n",
    "\n",
    "\n",
    "def prepare_dataset(dataset, batch_size=32):\n",
    "    \"\"\"Apply the shared preprocessing steps to one dataset split.\n",
    "\n",
    "    The text column is tokenized and passed through ``pad_op_text``, the\n",
    "    label column is cast to int32, the columns are renamed to\n",
    "    ``rename_columns`` and the result is batched.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dataset whose columns are named as in ``column_names``.\n",
    "        batch_size: Number of samples per batch.\n",
    "\n",
    "    Returns:\n",
    "        The preprocessed, batched dataset.\n",
    "    \"\"\"\n",
    "    dataset = dataset.map(\n",
    "        operations=[tokenizer, pad_op_text], input_columns=\"text_a\")\n",
    "    dataset = dataset.map(operations=[type_cast_op], input_columns=\"label\")\n",
    "    dataset = dataset.rename(\n",
    "        input_columns=column_names, output_columns=rename_columns)\n",
    "    return dataset.batch(batch_size)\n",
    "\n",
    "\n",
    "# Every split goes through the same pipeline, including the test split, so\n",
    "# it can be fed to the model without further work.\n",
    "dataset_train = prepare_dataset(dataset_train)\n",
    "dataset_val = prepare_dataset(dataset_val)\n",
    "dataset_test = prepare_dataset(dataset_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define model\n",
    "class BertForSequenceClassification(nn.Cell):\n",
    "    \"\"\"BERT encoder followed by a dense classification head.\n",
    "\n",
    "    Args:\n",
    "        config: BertConfig providing ``num_labels`` and ``hidden_size``.\n",
    "\n",
    "    Note:\n",
    "        ``__init__`` reads the module-level ``state_dict`` variable, which\n",
    "        must hold the pre-trained parameters before the class is instantiated.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        # nn.Cell does not take a config; passing it positionally would bind\n",
    "        # it to the base class's first parameter (auto_prefix).\n",
    "        super().__init__()\n",
    "        self.num_labels = config.num_labels\n",
    "        self.config = config\n",
    "        self.bert = BertModel(config)\n",
    "        # load the pre-trained parameters into the encoder only; the\n",
    "        # classifier below keeps its fresh initialization\n",
    "        mindspore.load_param_into_net(self.bert, state_dict)\n",
    "        self.classifier = nn.Dense(config.hidden_size, self.num_labels)\n",
    "\n",
    "    def construct(self, input_ids, attention_mask=None, token_type_ids=None,\n",
    "                  position_ids=None, head_mask=None):\n",
    "        \"\"\"Run the encoder and return the classification logits.\n",
    "\n",
    "        All arguments are forwarded unchanged to the wrapped BertModel.\n",
    "\n",
    "        Returns:\n",
    "            The output of the dense head applied to the encoder's second\n",
    "            output (used here as the pooled representation).\n",
    "        \"\"\"\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask\n",
    "        )\n",
    "        pooled_output = outputs[1]\n",
    "        logits = self.classifier(pooled_output)\n",
    "        return logits\n",
    "\n",
    "# please download the pre-trained model from \"https://download.mindspore.cn/toolkits/mindnlp/models/bert/bert-base-chinese.ckpt\"\n",
    "# and put it in the \"checkpoints\" folder\n",
    "model_path = os.path.join(\"checkpoints\", \"bert-base-chinese.ckpt\")\n",
    "# module-level parameter dict read by BertForSequenceClassification.__init__\n",
    "state_dict = mindspore.load_checkpoint(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set bert config and define parameters for training\n",
    "config = BertConfig(vocab_size=vocab_size, num_labels=3)\n",
    "model_instance = BertForSequenceClassification(config)\n",
    "\n",
    "model_instance.set_train(True)\n",
    "\n",
    "# The targets are sentiment class ids, not token ids, so no target value is\n",
    "# ignored. Using the [PAD] token id as ignore_index would silently drop from\n",
    "# the loss every sample whose class id happens to equal that token id.\n",
    "loss = nn.CrossEntropyLoss()\n",
    "optimizer = nn.Adam(model_instance.trainable_params(), learning_rate=1e-5)\n",
    "\n",
    "metric = Accuracy()\n",
    "\n",
    "# define callbacks to save checkpoints\n",
    "ckpoint_cb = CheckpointCallback(\n",
    "    save_path='sentimentbert_ckpt', epochs=1, keep_checkpoint_max=5)\n",
    "\n",
    "trainer = Trainer(network=model_instance, train_dataset=dataset_train,\n",
    "                  eval_dataset=dataset_val, metrics=metric,\n",
    "                  epochs=5, loss_fn=loss, optimizer=optimizer, callbacks=[ckpoint_cb])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "The train will start from the checkpoint saved in sentimentbert_ckpt.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 302/302 [01:04<00:00,  4.65it/s, loss=0.3034695] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: BertForSequenceClassification_epoch_0.ckpt has been saved in epoch:0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 34/34 [00:09<00:00,  3.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.7583333333333333}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████| 302/302 [00:44<00:00,  6.82it/s, loss=0.13073428]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: BertForSequenceClassification_epoch_1.ckpt has been saved in epoch:1.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 34/34 [00:01<00:00, 23.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.7916666666666666}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████| 302/302 [00:44<00:00,  6.84it/s, loss=0.08923956] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: BertForSequenceClassification_epoch_2.ckpt has been saved in epoch:2.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 34/34 [00:01<00:00, 23.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.825}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████| 302/302 [00:44<00:00,  6.83it/s, loss=0.06826735] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: BertForSequenceClassification_epoch_3.ckpt has been saved in epoch:3.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 34/34 [00:01<00:00, 24.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.825}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████| 302/302 [00:44<00:00,  6.81it/s, loss=0.049864933]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: BertForSequenceClassification_epoch_4.ckpt has been saved in epoch:4.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 34/34 [00:01<00:00, 23.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.825}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# start training; tgt_columns names the dataset column used as the target\n",
    "# (presumably handed to the loss function -- confirm against the Trainer docs)\n",
    "trainer.run(tgt_columns=\"label\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ljm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
